{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|default_exp models.XResNet1dPlus"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# XResNet1dPlus"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> This is a modified version of fastai's XResNet model (available on GitHub), adapted to 1d inputs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "from tsai.imports import *\n",
    "from tsai.models.layers import *\n",
    "from tsai.models.utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class XResNet1dPlus(nn.Sequential):\n",
    "    \"1d XResNet: a conv stem and stacked residual stages (`backbone`) followed by a `head`.\"\n",
    "    @delegates(ResBlock1dPlus)\n",
    "    def __init__(self, block=ResBlock1dPlus, expansion=4, layers=[3,4,6,3], fc_dropout=0.0, c_in=3, c_out=None, n_out=1000, seq_len=None, stem_szs=(32,32,64),\n",
    "                 widen=1.0, sa=False, act_cls=defaults.activation, ks=3, stride=2, coord=False, custom_head=None, block_szs_base=(64,128,256,512), **kwargs):\n",
    "\n",
    "        store_attr('block,expansion,act_cls,ks')\n",
    "        n_out = c_out or n_out # added for compatibility\n",
    "        # ValueError is a subclass of Exception, so callers catching the previous generic Exception still work\n",
    "        if ks % 2 == 0: raise ValueError('kernel size has to be odd!')\n",
    "\n",
    "        # Stem: three conv blocks, only the first one is strided\n",
    "        stem_szs = [c_in, *stem_szs]\n",
    "        stem = [ConvBlock(stem_szs[i], stem_szs[i+1], ks=ks, coord=coord, stride=stride if i==0 else 1,\n",
    "                          act=act_cls)\n",
    "                for i in range(3)]\n",
    "\n",
    "        # One width per stage: use `block_szs_base` and, when there are more stages than base widths,\n",
    "        # pad with half of the last base width\n",
    "        n_extra = max(0, len(layers) - len(block_szs_base))\n",
    "        block_szs = [int(o*widen) for o in (list(block_szs_base) + [int(block_szs_base[-1]/2)]*n_extra)]\n",
    "        # Prepend the input width of the first stage (hard-coded to the default stem output, 64) and drop widths that\n",
    "        # no stage uses, so that block_szs[-1] is always the width of the last stage actually built\n",
    "        # NOTE(review): 64//expansion assumes the block multiplies `ni` by `expansion` - confirm in ResBlock1dPlus\n",
    "        block_szs = ([64//expansion] + block_szs)[:len(layers)+1]\n",
    "        blocks    = self._make_blocks(layers, block_szs, sa, coord, stride, **kwargs)\n",
    "        backbone = nn.Sequential(*stem, MaxPool(ks=ks, stride=stride, padding=ks//2, ndim=1), *blocks)\n",
    "        self.head_nf = block_szs[-1]*expansion\n",
    "        if custom_head is not None: \n",
    "            # `custom_head` is either a ready-made module or a callable building one from (head_nf, n_out, seq_len)\n",
    "            if isinstance(custom_head, nn.Module): head = custom_head\n",
    "            else: head = custom_head(self.head_nf, n_out, seq_len)\n",
    "        else: head = nn.Sequential(AdaptiveAvgPool(sz=1, ndim=1), Flatten(), nn.Dropout(fc_dropout), nn.Linear(self.head_nf, n_out))\n",
    "        super().__init__(OrderedDict([('backbone', backbone), ('head', head)]))\n",
    "        # Initialises every submodule, including a `custom_head` passed as an already built module\n",
    "        self._init_cnn(self)\n",
    "\n",
    "    def _make_blocks(self, layers, block_szs, sa, coord, stride, **kwargs):\n",
    "        \"Build one stage per entry of `layers`; the first stage uses stride 1 and `sa` is only kept for stage `len(layers)-4`.\"\n",
    "        return [self._make_layer(ni=block_szs[i], nf=block_szs[i+1], blocks=l, coord=coord, \n",
    "                                 stride=1 if i==0 else stride, sa=sa and i==len(layers)-4, **kwargs)\n",
    "                for i,l in enumerate(layers)]\n",
    "\n",
    "    def _make_layer(self, ni, nf, blocks, coord, stride, sa, **kwargs):\n",
    "        \"Stack `blocks` blocks: only the first receives `ni` and `stride`, only the last may receive `sa`.\"\n",
    "        return nn.Sequential(\n",
    "            *[self.block(self.expansion, ni if i==0 else nf, nf, coord=coord, stride=stride if i==0 else 1,\n",
    "                      sa=sa and i==(blocks-1), act_cls=self.act_cls, ks=self.ks, **kwargs)\n",
    "              for i in range(blocks)])\n",
    "    \n",
    "    def _init_cnn(self, m):\n",
    "        \"Recursively zero the bias and Kaiming-normal initialise the conv/linear weights of `m` and its children.\"\n",
    "        # Inspect the visited module `m`, not `self`: `self` is always the top-level nn.Sequential, which has\n",
    "        # neither a bias nor a conv/linear weight, so checking it leaves every layer untouched\n",
    "        if getattr(m, 'bias', None) is not None: nn.init.constant_(m.bias, 0)\n",
    "        if isinstance(m, (nn.Conv1d,nn.Conv2d,nn.Conv3d,nn.Linear)): nn.init.kaiming_normal_(m.weight)\n",
    "        for l in m.children(): self._init_cnn(l)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "def _xresnetplus(expansion, layers, c_in, c_out, seq_len=None, **kwargs):\n",
    "    \"Build an `XResNet1dPlus` made of `ResBlock1dPlus` blocks with the given `expansion` and `layers`.\"\n",
    "    model = XResNet1dPlus(block=ResBlock1dPlus, expansion=expansion, layers=layers,\n",
    "                          c_in=c_in, c_out=c_out, seq_len=seq_len, **kwargs)\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "# Signatures delegate to ResBlock1dPlus, the block class these factories actually instantiate\n",
    "@delegates(ResBlock1dPlus)\n",
    "def xresnet1d18plus(c_in, c_out, seq_len=None, act=nn.ReLU, **kwargs):\n",
    "    \"XResNet1dPlus with expansion=1 and layers=[2,2,2,2].\"\n",
    "    return _xresnetplus(1, [2, 2, 2, 2], c_in, c_out, seq_len=seq_len, act_cls=act, **kwargs)\n",
    "\n",
    "@delegates(ResBlock1dPlus)\n",
    "def xresnet1d34plus(c_in, c_out, seq_len=None, act=nn.ReLU, **kwargs):\n",
    "    \"XResNet1dPlus with expansion=1 and layers=[3,4,6,3].\"\n",
    "    return _xresnetplus(1, [3, 4, 6, 3], c_in, c_out, seq_len=seq_len, act_cls=act, **kwargs)\n",
    "\n",
    "@delegates(ResBlock1dPlus)\n",
    "def xresnet1d50plus(c_in, c_out, seq_len=None, act=nn.ReLU, **kwargs):\n",
    "    \"XResNet1dPlus with expansion=4 and layers=[3,4,6,3].\"\n",
    "    return _xresnetplus(4, [3, 4, 6, 3], c_in, c_out, seq_len=seq_len, act_cls=act, **kwargs)\n",
    "\n",
    "@delegates(ResBlock1dPlus)\n",
    "def xresnet1d101plus(c_in, c_out, seq_len=None, act=nn.ReLU, **kwargs):\n",
    "    \"XResNet1dPlus with expansion=4 and layers=[3,4,23,3].\"\n",
    "    return _xresnetplus(4, [3, 4, 23, 3], c_in, c_out, seq_len=seq_len, act_cls=act, **kwargs)\n",
    "\n",
    "@delegates(ResBlock1dPlus)\n",
    "def xresnet1d152plus(c_in, c_out, seq_len=None, act=nn.ReLU, **kwargs):\n",
    "    \"XResNet1dPlus with expansion=4 and layers=[3,8,36,3].\"\n",
    "    return _xresnetplus(4, [3, 8, 36, 3], c_in, c_out, seq_len=seq_len, act_cls=act, **kwargs)\n",
    "\n",
    "@delegates(ResBlock1dPlus)\n",
    "def xresnet1d18_deepplus(c_in, c_out, seq_len=None, act=nn.ReLU, **kwargs):\n",
    "    \"XResNet1dPlus with expansion=1 and layers=[2,2,2,2,1,1].\"\n",
    "    return _xresnetplus(1, [2,2,2,2,1,1], c_in, c_out, seq_len=seq_len, act_cls=act, **kwargs)\n",
    "\n",
    "@delegates(ResBlock1dPlus)\n",
    "def xresnet1d34_deepplus(c_in, c_out, seq_len=None, act=nn.ReLU, **kwargs):\n",
    "    \"XResNet1dPlus with expansion=1 and layers=[3,4,6,3,1,1].\"\n",
    "    return _xresnetplus(1, [3,4,6,3,1,1], c_in, c_out, seq_len=seq_len, act_cls=act, **kwargs)\n",
    "\n",
    "@delegates(ResBlock1dPlus)\n",
    "def xresnet1d50_deepplus(c_in, c_out, seq_len=None, act=nn.ReLU, **kwargs):\n",
    "    \"XResNet1dPlus with expansion=4 and layers=[3,4,6,3,1,1].\"\n",
    "    return _xresnetplus(4, [3,4,6,3,1,1], c_in, c_out, seq_len=seq_len, act_cls=act, **kwargs)\n",
    "\n",
    "@delegates(ResBlock1dPlus)\n",
    "def xresnet1d18_deeperplus(c_in, c_out, seq_len=None, act=nn.ReLU, **kwargs):\n",
    "    \"XResNet1dPlus with expansion=1 and layers=[2,2,1,1,1,1,1,1].\"\n",
    "    return _xresnetplus(1, [2,2,1,1,1,1,1,1], c_in, c_out, seq_len=seq_len, act_cls=act, **kwargs)\n",
    "\n",
    "@delegates(ResBlock1dPlus)\n",
    "def xresnet1d34_deeperplus(c_in, c_out, seq_len=None, act=nn.ReLU, **kwargs):\n",
    "    \"XResNet1dPlus with expansion=1 and layers=[3,4,6,3,1,1,1,1].\"\n",
    "    return _xresnetplus(1, [3,4,6,3,1,1,1,1], c_in, c_out, seq_len=seq_len, act_cls=act, **kwargs)\n",
    "\n",
    "@delegates(ResBlock1dPlus)\n",
    "def xresnet1d50_deeperplus(c_in, c_out, seq_len=None, act=nn.ReLU, **kwargs):\n",
    "    \"XResNet1dPlus with expansion=4 and layers=[3,4,6,3,1,1,1,1].\"\n",
    "    return _xresnetplus(4, [3,4,6,3,1,1,1,1], c_in, c_out, seq_len=seq_len, act_cls=act, **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Smoke test: a random batch of 32 samples (3 channels, 50 steps) must yield one row of 2 outputs per sample.\n",
    "# The shape is asserted instead of displaying the tensor, whose random values differ on every run.\n",
    "net = xresnet1d18plus(3, 2, coord=True)\n",
    "x = torch.rand(32, 3, 50)\n",
    "test_eq(net(x).shape, (32, 2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 xresnet1d18plus\n",
      "block <class 'tsai.models.layers.ResBlock1dPlus'> expansion 1 layers [2, 2, 2, 2]\n",
      "1 xresnet1d34plus\n",
      "block <class 'tsai.models.layers.ResBlock1dPlus'> expansion 1 layers [3, 4, 6, 3]\n",
      "2 xresnet1d50plus\n",
      "block <class 'tsai.models.layers.ResBlock1dPlus'> expansion 4 layers [3, 4, 6, 3]\n",
      "3 xresnet1d18_deepplus\n",
      "block <class 'tsai.models.layers.ResBlock1dPlus'> expansion 1 layers [2, 2, 2, 2, 1, 1]\n",
      "4 xresnet1d34_deepplus\n",
      "block <class 'tsai.models.layers.ResBlock1dPlus'> expansion 1 layers [3, 4, 6, 3, 1, 1]\n",
      "5 xresnet1d50_deepplus\n",
      "block <class 'tsai.models.layers.ResBlock1dPlus'> expansion 4 layers [3, 4, 6, 3, 1, 1]\n",
      "6 xresnet1d18_deeperplus\n",
      "block <class 'tsai.models.layers.ResBlock1dPlus'> expansion 1 layers [2, 2, 1, 1, 1, 1, 1, 1]\n",
      "7 xresnet1d34_deeperplus\n",
      "block <class 'tsai.models.layers.ResBlock1dPlus'> expansion 1 layers [3, 4, 6, 3, 1, 1, 1, 1]\n",
      "8 xresnet1d50_deeperplus\n",
      "block <class 'tsai.models.layers.ResBlock1dPlus'> expansion 4 layers [3, 4, 6, 3, 1, 1, 1, 1]\n"
     ]
    }
   ],
   "source": [
    "bs, c_in, seq_len, c_out = 2, 4, 32, 2\n",
    "x = torch.rand(bs, c_in, seq_len)\n",
    "archs = [\n",
    "    xresnet1d18plus, xresnet1d34plus, xresnet1d50plus,\n",
    "    xresnet1d18_deepplus, xresnet1d34_deepplus, xresnet1d50_deepplus,\n",
    "    xresnet1d18_deeperplus, xresnet1d34_deeperplus, xresnet1d50_deeperplus,\n",
    "    # Long test (kept disabled):\n",
    "    # xresnet1d101plus, xresnet1d152plus,\n",
    "]\n",
    "# Every architecture must map a (bs, c_in, seq_len) batch to a (bs, c_out) output\n",
    "for i, arch in enumerate(archs):\n",
    "    print(i, arch.__name__)\n",
    "    model = arch(c_in, c_out, sa=True, act=Mish, coord=True)\n",
    "    test_eq(model(x).shape, (bs, c_out))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "block <class 'tsai.models.layers.ResBlock1dPlus'> expansion 1 layers [3, 4, 6, 3]\n"
     ]
    }
   ],
   "source": [
    "# NOTE(review): 38 and 22 are presumably the number of norm layers and the number of those not zero-initialised - confirm\n",
    "m = xresnet1d34plus(4, 2, act=Mish)\n",
    "bn_layers = get_layers(m, is_bn)\n",
    "test_eq(len(bn_layers), 38)\n",
    "bn_weight_means = check_weight(m, is_bn)[0]\n",
    "test_eq(bn_weight_means.sum(), 22)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/javascript": "IPython.notebook.save_checkpoint();",
      "text/plain": [
       "<IPython.core.display.Javascript object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/Users/nacho/notebooks/tsai/nbs/059_models.XResNet1dPlus.ipynb saved at 2023-03-26 16:07:44\n",
      "Correct notebook to script conversion! 😃\n",
      "Sunday 26/03/23 16:07:46 CEST\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "                <audio  controls=\"controls\" autoplay=\"autoplay\">\n",
       "                    <source src=\"data:audio/wav;base64,UklGRvQHAABXQVZFZm10IBAAAAABAAEAECcAACBOAAACABAAZGF0YdAHAAAAAPF/iPh/gOoOon6w6ayCoR2ZeyfbjobxK+F2Hs0XjKc5i3DGvzaTlEaraE+zz5uLUl9f46fHpWJdxVSrnfmw8mYEScqUP70cb0Q8X41uysJ1si6Eh1jYzXp9IE2DzOYsftYRyoCY9dJ/8QICgIcEun8D9PmAaBPlfT7lq4MFIlh61tYPiCswIHX+yBaOqT1QbuW7qpVQSv9lu6+xnvRVSlyopAypbGBTUdSalrSTaUBFYpInwUpxOzhti5TOdndyKhCGrdwAfBUcXIJB69p+Vw1egB76+n9q/h6ADglbf4LvnIHfF/981ODThF4m8HiS0riJVjQ6c+/EOZCYQfJrGrhBmPVNMmNArLKhQlkXWYqhbaxXY8ZNHphLuBJsZUEckCTFVHMgNKGJytIDeSUmw4QN4Qx9pReTgb3vYX/TCBuApf75f+P5Y4CRDdN+B+tngk8c8nt03CKGqipgd13OhotwOC5x9MCAknFFcmlmtPmagFFFYOCo0qRzXMhVi57pryNmIEqJlRi8bm52PfuNM8k4dfQv+4cO12l6zCGdg3jl730uE/KAPvS+f0wEAoAsA89/XfXQgBESIn6S5luDtiC8eh/YmIfpLqt1OMp5jXg8/24MveqUNUnPZsqw0Z3yVDldnaUOqIZfXlKrm36zzWhjRhaT+r+ncHI5/otUzfd2uSt7hl/bqXtoHaCC6+mqfrAOeoDD+PJ/xf8RgLMHfH/b8GeBihZIfSXidoQSJWB52NM1iRkzz3MkxpKPbUCrbDu5d5fgTAxkSK3JoEhYD1p2omere2LZTuqYLbdWa49Cx5Dww7tyXDUnioXRkHhwJyKFvd/AfPoYy4Fl7j1/LQorgEr9/X89+0qAOAwAf13sJoL8Gkd8wt25hWIp3Heez/eKODfPcSPCzpFNRDVqf7UlmnNQKGHgqd+jgVvJVm2f265QZTpLS5byur1tpT6ajvrHq3Q2MXWIxtUCehoj8YMk5LB9hRQegeTypn+nBQWA0QHgf7f2q4C5EFt+5ucOg2YfHXtq2SSHpS0ydnTL4IxFO6pvNb4ulBdInWfcsfSc7VMmXpSmE6eeXmZThJxpsgRohEfOk86+AHCoOpOMFsx1dv8s6oYT2k17uR7ngpXod34IEJqAaPfnfyABCIBZBpl/NPI2gTQVjX134x2ExSPMeR7VtYjZMWJ0W8ftjkA/YW1durCWykvjZFKu4p9LVwVbZKNkqpxh6U+6mRC2mGq2Q3SRvsIgcpc2sIpD0Bp4uiiFhW3ecXxOGgaCDe0Vf4cLPoDv+/5/mfw1gN4KKX+17emBqBmYfBHfVYUZKFR44NBtiv41bHJUwx+RJkP1apu2VJlkTwli4qrwoo1ax1dToNCtemRSTBGXz7kJbdM/PY/Dxht0dTLziH7Ul3loJEiE0uJsfdsVTYGL8Yt/AgcMgHYA7X8S+IqAYA+QfjzpxIIVHnp7tdqzhmAstXaxzEqMETpScGC/dJP3Rmdo8LIZnOVSEF+Opxumsl1sVF+dVrE5Z6NIiZSkvVdv2zsqjdnK8HVDLlyHyNjuegogM4NA5z9+YRG9gA722H97AgOA/gSyf43zCIHdE899yuTIg3ciNXpm1jmImTDwdJPITI4RPhRugbvslbFKt2Vfr/6eTFb4W1WkY6m6YPdQjJr2tNZp3EQlko7BgXHRNz2LAc+gdwMq7IUf3R58ohtFgrbr6n7hDFWAlPr8f/T9I4CECU9/De+vgVQY5nxh4POEzybJeCTS5YnCNAZzhsRzkP1Bsmu4t4aYU07nYuerA6KWWcJYO6HHrKJjaE3Zl624UWz/QOOPjcWHc7QzdIk40yl5tCWjhIDhJX0xF4CBMvBsf10IF4Ac//Z/bPlsgAcO
wn6S6n6CwxzUewLcRoYaKzV38M23i9o493CNwL6S1UUuaQe0QpvbUfdfiqglpcRccFU+nkWwambASUiVfLyqbg49xY2eyWh1hy/Sh37XjHpaIYKD7OUEfrgS5IC09MV/1gMBgKMDyH/n9N6AhhINfh7mdoMoIZt6r9fAh1cvfHXNya6N4DzDbqi8K5WWSYlmbbAdnkpV6FxJpWSo1V8DUmGb3rMRaQBG2JJgwN9wCDnNi8HNI3dKK1aG0dvHe/UciIJf6rt+Og5wgDn59X9P/xWAKQhxf2XweYH+FjB9suGVhIMlOnlo02GJhTOdc7vFyo/TQGxs2Li7lz9NwmPurBihnVi7WSWiwKvGYntOpJiOt5drKUKMkFnE8HLxNPmJ9NG4eP8mAYUv4Np8hhi3gdruSX+3CSWAwP38f8f6UoCuDPF+6Os8gnAbKnxQ3d2F0imydzDPKIuiN5lxu8EKkrFE82kftW2az1DbYImpMqTUW3FWIJ83r5hl2koJlla7+m0+PmSOZcjcdMgwS4g11iZ6qCLUg5jkxn0QFA6BWvOvfzEFBIBHAtp/Qfa3gC4RSH5y5yeD2B/8evnYS4cULgR2CMsUja47cG/QvW6UeEhXZ3+xP51GVNVdP6Zpp+1eDFM5nMeySWghR4+TNL85cD46YIyCzKJ2kCzEhoTabXtGHs+CCemJfpMPjoDe9+t/qQALgM8Gj3++8UaBqRV2fQTjO4Q3JKd5r9TgiEYyMHTxxiWPpz8jbfq585YpTJpk960xoKFXsVoTo7yq6GGMTw==\" type=\"audio/wav\" />\n",
       "                    Your browser does not support the audio element.\n",
       "                </audio>\n",
       "              "
      ],
      "text/plain": [
       "<IPython.lib.display.Audio object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "#|eval: false\n",
    "#|hide\n",
    "from tsai.export import get_nb_name; nb_name = get_nb_name(locals())\n",
    "from tsai.imports import create_scripts; create_scripts(nb_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
